{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pL--_KGdYoBz"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "uBDvXpYzYnGj"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HQzaEQuJiW_d"
      },
      "source": [
        "# TFRecord 和 tf.Example\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/tutorials/load_data/tfrecord\" class=\"\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" class=\"\">在 TensorFlow.org 上查看 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/load_data/tfrecord.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" class=\"\">在 Google Colab 中运行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/load_data/tfrecord.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" class=\"\">在  GitHub 上查看源代码</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/tutorials/load_data/tfrecord.ipynb\" class=\"\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" class=\"\">下载笔记本</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3pkUd_9IZCFO"
      },
      "source": [
        "为了高效地读取数据，比较有帮助的一种做法是对数据进行序列化并将其存储在一组可线性读取的文件（每个文件 100-200MB）中。这尤其适用于通过网络进行流式传输的数据。这种做法对缓冲任何数据预处理也十分有用。\n",
        "\n",
        "TFRecord 格式是一种用于存储二进制记录序列的简单格式。\n",
        "\n",
        "[协议缓冲区](https://developers.google.com/protocol-buffers/)是一个跨平台、跨语言的库，用于高效地序列化结构化数据。\n",
        "\n",
        "协议消息由 `.proto` 文件定义，这通常是了解消息类型最简单的方法。\n",
        "\n",
        "`tf.Example` 消息（或 protobuf）是一种灵活的消息类型，表示 `{\"string\": value}` 映射。它专为 TensorFlow 而设计，并被用于 [TFX](https://tensorflow.google.cn/tfx/) 等高级 API。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ac83J0QxjhFt"
      },
      "source": [
        "本笔记本将演示如何创建、解析和使用 `tf.Example` 消息，以及如何在 `.tfrecord` 文件之间对 `tf.Example` 消息进行序列化、写入和读取。\n",
        "\n",
        "注：这些结构虽然有用，但并不是强制的。您无需转换现有代码即可使用 TFRecord，除非您正在使用 [tf.data](https://tensorflow.google.cn/guide/datasets) 且读取数据仍是训练的瓶颈。有关数据集性能的提示，请参阅[数据输入流水线性能](https://tensorflow.google.cn/guide/performance/datasets)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WkRreBf1eDVc"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ja7sezsmnXph"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import numpy as np\n",
        "import IPython.display as display"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e5Kq88ccUWQV"
      },
      "source": [
        "## `tf.Example`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VrdQHgvNijTi"
      },
      "source": [
        "### `tf.Example` 的数据类型"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lZw57Qrn4CTE"
      },
      "source": [
        "从根本上讲，`tf.Example` 是 `{\"string\": tf.train.Feature}` 映射。\n",
        "\n",
        "`tf.train.Feature` 消息类型可以接受以下三种类型（请参阅 [`.proto` 文件](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/feature.proto)）。大多数其他通用类型也可以强制转换成下面的其中一种：\n",
        "\n",
        "1. `tf.train.BytesList`（可强制转换自以下类型）\n",
        "\n",
        "- `string`\n",
        "- `byte`\n",
        "\n",
        "1. `tf.train.FloatList`（可强制转换自以下类型）\n",
        "\n",
        "- `float` (`float32`)\n",
        "- `double` (`float64`)\n",
        "\n",
        "1. `tf.train.Int64List`（可强制转换自以下类型）\n",
        "\n",
        "- `bool`\n",
        "- `enum`\n",
        "- `int32`\n",
        "- `uint32`\n",
        "- `int64`\n",
        "- `uint64`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_e3g9ExathXP"
      },
      "source": [
        "为了将标准 TensorFlow 类型转换为兼容 `tf.Example` 的 `tf.train.Feature`，可以使用下面的快捷函数。请注意，每个函数会接受标量输入值并返回包含上述三种 `list` 类型之一的 `tf.train.Feature`："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mbsPOUpVtYxA"
      },
      "outputs": [],
      "source": [
        "# The following functions can be used to convert a value to a type compatible\n",
        "# with tf.Example.\n",
        "\n",
        "def _bytes_feature(value):\n",
        "  \"\"\"Returns a bytes_list from a string / byte.\"\"\"\n",
        "  if isinstance(value, type(tf.constant(0))):\n",
        "    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.\n",
        "  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n",
        "\n",
        "def _float_feature(value):\n",
        "  \"\"\"Returns a float_list from a float / double.\"\"\"\n",
        "  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))\n",
        "\n",
        "def _int64_feature(value):\n",
        "  \"\"\"Returns an int64_list from a bool / enum / int / uint.\"\"\"\n",
        "  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wst0v9O8hgzy"
      },
      "source": [
        "注：为了简单起见，本示例仅使用标量输入。要处理非标量特征，最简单的方法是使用 `tf.io.serialize_tensor` 将张量转换为二进制字符串。在 TensorFlow 中，字符串是标量。使用 `tf.io.parse_tensor` 可将二进制字符串转换回张量。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vsMbkkC8xxtB"
      },
      "source": [
        "下面是有关这些函数如何工作的一些示例。请注意不同的输入类型和标准化的输出类型。如果函数的输入类型与上述可强制转换的类型均不匹配，则该函数将引发异常（例如，`_int64_feature(1.0)` 将出错，因为 `1.0` 是浮点数，应该用于 `_float_feature` 函数）："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hZzyLGr0u73y"
      },
      "outputs": [],
      "source": [
        "print(_bytes_feature(b'test_string'))\n",
        "print(_bytes_feature(u'test_bytes'.encode('utf-8')))\n",
        "\n",
        "print(_float_feature(np.exp(1)))\n",
        "\n",
        "print(_int64_feature(True))\n",
        "print(_int64_feature(1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nj1qpfQU5qmi"
      },
      "source": [
        "可以使用 `.SerializeToString` 方法将所有协议消息序列化为二进制字符串："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5afZkORT5pjm"
      },
      "outputs": [],
      "source": [
        "feature = _float_feature(np.exp(1))\n",
        "\n",
        "feature.SerializeToString()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "laKnw9F3hL-W"
      },
      "source": [
        "### 创建 `tf.Example` 消息"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b_MEnhxchQPC"
      },
      "source": [
        "假设您要根据现有数据创建 `tf.Example` 消息。在实践中，数据集可能来自任何地方，但是从单个观测值创建 `tf.Example` 消息的过程相同：\n",
        "\n",
        "1. 在每个观测结果中，需要使用上述其中一种函数，将每个值转换为包含三种兼容类型之一的 `tf.train.Feature`。\n",
        "\n",
        "2. 创建一个从特征名称字符串到第 1 步中生成的编码特征值的映射（字典）。\n",
        "\n",
        "3. 将第 2 步中生成的映射转换为 [`Features` 消息](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/feature.proto#L85)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4EgFQ2uHtchc"
      },
      "source": [
        "在此笔记本中，您将使用 NumPy 创建一个数据集。\n",
        "\n",
        "此数据集将具有 4 个特征：\n",
        "\n",
        "- 具有相等 `False` 或 `True` 概率的布尔特征\n",
        "- 从 `[0, 5]` 均匀随机选择的整数特征\n",
        "- 通过将整数特征作为索引从字符串表生成的字符串特征\n",
        "- 来自标准正态分布的浮点特征\n",
        "\n",
        "请思考一个样本，其中包含来自上述每个分布的 10,000 个独立且分布相同的观测值："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CnrguFAy3YQv"
      },
      "outputs": [],
      "source": [
        "# The number of observations in the dataset.\n",
        "n_observations = int(1e4)\n",
        "\n",
        "# Boolean feature, encoded as False or True.\n",
        "feature0 = np.random.choice([False, True], n_observations)\n",
        "\n",
        "# Integer feature, random from 0 to 4.\n",
        "feature1 = np.random.randint(0, 5, n_observations)\n",
        "\n",
        "# String feature\n",
        "strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])\n",
        "feature2 = strings[feature1]\n",
        "\n",
        "# Float feature, from a standard normal distribution\n",
        "feature3 = np.random.randn(n_observations)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aGrscehJr7Jd"
      },
      "source": [
        "您可以使用 `_bytes_feature`、`_float_feature` 或 `_int64_feature` 将下面的每个特征强制转换为兼容 `tf.Example` 的类型。然后，可以通过下面的已编码特征创建 `tf.Example` 消息："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RTCS49Ij_kUw"
      },
      "outputs": [],
      "source": [
        "def serialize_example(feature0, feature1, feature2, feature3):\n",
        "  \"\"\"\n",
        "  Creates a tf.Example message ready to be written to a file.\n",
        "  \"\"\"\n",
        "  # Create a dictionary mapping the feature name to the tf.Example-compatible\n",
        "  # data type.\n",
        "  feature = {\n",
        "      'feature0': _int64_feature(feature0),\n",
        "      'feature1': _int64_feature(feature1),\n",
        "      'feature2': _bytes_feature(feature2),\n",
        "      'feature3': _float_feature(feature3),\n",
        "  }\n",
        "\n",
        "  # Create a Features message using tf.train.Example.\n",
        "\n",
        "  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))\n",
        "  return example_proto.SerializeToString()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XftzX9CN_uGT"
      },
      "source": [
        "例如，假设您从数据集中获得了一个观测值 `[False, 4, bytes('goat'), 0.9876]`。您可以使用 `create_message()` 创建和打印此观测值的 `tf.Example` 消息。如上所述，每个观测值将被写为一条 `Features` 消息。请注意，`tf.Example` [消息](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto#L88)只是 `Features` 消息外围的包装器："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N8BtSx2RjYcb"
      },
      "outputs": [],
      "source": [
        "# This is an example observation from the dataset.\n",
        "\n",
        "example_observation = []\n",
        "\n",
        "serialized_example = serialize_example(False, 4, b'goat', 0.9876)\n",
        "serialized_example"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_pbGATlG6u-4"
      },
      "source": [
        "要解码消息，请使用 `tf.train.Example.FromString` 方法。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dGim-mEm6vit"
      },
      "outputs": [],
      "source": [
        "example_proto = tf.train.Example.FromString(serialized_example)\n",
        "example_proto"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o6qxofy89obI"
      },
      "source": [
        "## TFRecords 格式详细信息\n",
        "\n",
        "TFRecord 文件包含一系列记录。该文件只能按顺序读取。\n",
        "\n",
        "每条记录包含一个字节字符串（用于数据有效负载），外加数据长度，以及用于完整性检查的 CRC32C（使用 Castagnoli 多项式的 32 位 CRC）哈希值。\n",
        "\n",
        "每条记录会存储为以下格式：\n",
        "\n",
        "```\n",
        "uint64 length uint32 masked_crc32_of_length byte   data[length] uint32 masked_crc32_of_data\n",
        "```\n",
        "\n",
        "将记录连接起来以生成文件。[此处](https://en.wikipedia.org/wiki/Cyclic_redundancy_check)对 CRC 进行了说明，且 CRC 的掩码为：\n",
        "\n",
        "```\n",
        "masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul\n",
        "```\n",
        "\n",
        "注：不需要在 TFRecord 文件中使用 `tf.Example`。`tf.Example` 只是将字典序列化为字节字符串的一种方法。文本行、编码的图像数据，或序列化的张量（使用 `tf.io.serialize_tensor`，或在加载时使用 `tf.io.parse_tensor`）。有关更多选项，请参阅 `tf.io` 模块。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y-Hjmee-fbLH"
      },
      "source": [
        "## 使用 `tf.data` 的 TFRecord 文件"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GmehkCCT81Ez"
      },
      "source": [
        "`tf.data` 模块还提供用于在 TensorFlow 中读取和写入数据的工具。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1FISEuz8ubu3"
      },
      "source": [
        "### 写入 TFRecord 文件\n",
        "\n",
        "要将数据放入数据集中，最简单的方式是使用 `from_tensor_slices` 方法。\n",
        "\n",
        "若应用于数组，将返回标量数据集："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mXeaukvwu5_-"
      },
      "outputs": [],
      "source": [
        "tf.data.Dataset.from_tensor_slices(feature1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f-q0VKyZvcad"
      },
      "source": [
        "若应用于数组的元组，将返回元组的数据集："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H5sWyu1kxnvg"
      },
      "outputs": [],
      "source": [
        "features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))\n",
        "features_dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m1C-t71Nywze"
      },
      "outputs": [],
      "source": [
        "# Use `take(1)` to only pull one example from the dataset.\n",
        "for f0,f1,f2,f3 in features_dataset.take(1):\n",
        "  print(f0)\n",
        "  print(f1)\n",
        "  print(f2)\n",
        "  print(f3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mhIe63awyZYd"
      },
      "source": [
        "使用 `tf.data.Dataset.map` 方法可将函数应用于 `Dataset` 的每个元素。\n",
        "\n",
        "映射函数必须在 TensorFlow 计算图模式下进行运算（它必须在 `tf.Tensors` 上运算并返回）。可以使用 `tf.py_function` 包装非张量函数（如 `serialize_example`）以使其兼容。\n",
        "\n",
        "使用 `tf.py_function` 需要指定形状和类型信息，否则它将不可用："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "apB5KYrJzjPI"
      },
      "outputs": [],
      "source": [
        "def tf_serialize_example(f0,f1,f2,f3):\n",
        "  tf_string = tf.py_function(\n",
        "    serialize_example,\n",
        "    (f0,f1,f2,f3),  # pass these args to the above function.\n",
        "    tf.string)      # the return type is `tf.string`.\n",
        "  return tf.reshape(tf_string, ()) # The result is a scalar"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lHFjW4u4Npz9"
      },
      "outputs": [],
      "source": [
        "tf_serialize_example(f0,f1,f2,f3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CrFZ9avE3HUF"
      },
      "source": [
        "将此函数应用于数据集中的每个元素："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VDeqYVbW3ww9"
      },
      "outputs": [],
      "source": [
        "serialized_features_dataset = features_dataset.map(tf_serialize_example)\n",
        "serialized_features_dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DlDfuh46bRf6"
      },
      "outputs": [],
      "source": [
        "def generator():\n",
        "  for features in features_dataset:\n",
        "    yield serialize_example(*features)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iv9oXKrcbhvX"
      },
      "outputs": [],
      "source": [
        "serialized_features_dataset = tf.data.Dataset.from_generator(\n",
        "    generator, output_types=tf.string, output_shapes=())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Dqz8C4D5cIj9"
      },
      "outputs": [],
      "source": [
        "serialized_features_dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p6lw5VYpjZZC"
      },
      "source": [
        "并将它们写入 TFRecord 文件："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vP1VgTO44UIE"
      },
      "outputs": [],
      "source": [
        "filename = 'test.tfrecord'\n",
        "writer = tf.data.experimental.TFRecordWriter(filename)\n",
        "writer.write(serialized_features_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6aV0GQhV8tmp"
      },
      "source": [
        "### 读取 TFRecord 文件"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o3J5D4gcSy8N"
      },
      "source": [
        "您还可以使用 `tf.data.TFRecordDataset` 类来读取 TFRecord 文件。\n",
        "\n",
        "有关通过 `tf.data` 使用 TFRecord 文件的详细信息，请参见[此处](https://tensorflow.google.cn/guide/datasets#consuming_tfrecord_data)。\n",
        "\n",
        "使用 `TFRecordDataset` 对于标准化输入数据和优化性能十分有用。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6OjX6UZl-bHC"
      },
      "outputs": [],
      "source": [
        "filenames = [filename]\n",
        "raw_dataset = tf.data.TFRecordDataset(filenames)\n",
        "raw_dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6_EQ9i2E_-Fz"
      },
      "source": [
        "此时，数据集包含序列化的 `tf.train.Example` 消息。迭代时，它会将其作为标量字符串张量返回。\n",
        "\n",
        "使用 `.take` 方法仅显示前 10 条记录。\n",
        "\n",
        "注：在 `tf.data.Dataset` 上进行迭代仅在启用了 Eager Execution 时有效。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hxVXpLz_AJlm"
      },
      "outputs": [],
      "source": [
        "for raw_record in raw_dataset.take(10):\n",
        "  print(repr(raw_record))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W-6oNzM4luFQ"
      },
      "source": [
        "可以使用以下函数对这些张量进行解析。请注意，这里的 `feature_description` 是必需的，因为数据集使用计算图执行，并且需要以下描述来构建它们的形状和类型签名："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zQjbIR1nleiy"
      },
      "outputs": [],
      "source": [
        "# Create a description of the features.\n",
        "feature_description = {\n",
        "    'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),\n",
        "    'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),\n",
        "    'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),\n",
        "    'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),\n",
        "}\n",
        "\n",
        "def _parse_function(example_proto):\n",
        "  # Parse the input `tf.Example` proto using the dictionary above.\n",
        "  return tf.io.parse_single_example(example_proto, feature_description)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gWETjUqhEQZf"
      },
      "source": [
        "或者，使用 `tf.parse example` 一次解析整个批次。使用 `tf.data.Dataset.map` 方法将此函数应用于数据集中的每一项："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6Ob7D-zmBm1w"
      },
      "outputs": [],
      "source": [
        "parsed_dataset = raw_dataset.map(_parse_function)\n",
        "parsed_dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sNV-XclGnOvn"
      },
      "source": [
        "使用 Eager Execution 在数据集中显示观测值。此数据集中有 10,000 个观测值，但只会显示前 10 个。数据会作为特征字典进行显示。每一项都是一个 `tf.Tensor`，此张量的 `numpy` 元素会显示特征的值："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x2LT2JCqhoD_"
      },
      "outputs": [],
      "source": [
        "for parsed_record in parsed_dataset.take(10):\n",
        "  print(repr(parsed_record))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cig9EodTlDmg"
      },
      "source": [
        "在这里，`tf.parse_example` 函数会将 `tf.Example` 字段解压缩为标准张量。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jyg1g3gU7DNn"
      },
      "source": [
        "## Python 中的 TFRecord 文件"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3FXG3miA7Kf1"
      },
      "source": [
        "`tf.io` 模块还包含用于读取和写入 TFRecord 文件的纯 Python 函数。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CKn5uql2lAaN"
      },
      "source": [
        "### 写入 TFRecord 文件"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LNW_FA-GQWXs"
      },
      "source": [
        "接下来，将 10,000 个观测值写入文件 `test.tfrecord`。每个观测值都将转换为一条 `tf.Example` 消息，然后被写入文件。随后，您可以验证是否已创建 `test.tfrecord` 文件："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MKPHzoGv7q44"
      },
      "outputs": [],
      "source": [
        "# Write the `tf.Example` observations to the file.\n",
        "with tf.io.TFRecordWriter(filename) as writer:\n",
        "  for i in range(n_observations):\n",
        "    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])\n",
        "    writer.write(example)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EjdFHHJMpUUo"
      },
      "outputs": [],
      "source": [
        "!du -sh {filename}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2osVRnYNni-E"
      },
      "source": [
        "### 读取 TFRecord 文件\n",
        "\n",
        "您可以使用 `tf.train.Example.ParseFromString` 轻松解析以下序列化张量："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U3tnd3LerOtV"
      },
      "outputs": [],
      "source": [
        "filenames = [filename]\n",
        "raw_dataset = tf.data.TFRecordDataset(filenames)\n",
        "raw_dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nsEAACHcnm3f"
      },
      "outputs": [],
      "source": [
        "for raw_record in raw_dataset.take(1):\n",
        "  example = tf.train.Example()\n",
        "  example.ParseFromString(raw_record.numpy())\n",
        "  print(example)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S0tFDrwdoj3q"
      },
      "source": [
        "## 演练：读取和写入图像数据"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rjN2LFxFpcR9"
      },
      "source": [
        "下面是关于如何使用 TFRecord 读取和写入图像数据的端到端示例。您将使用图像作为输入数据，将数据写入 TFRecord 文件，然后将文件读取回来并显示图像。\n",
        "\n",
        "如果您想在同一个输入数据集上使用多个模型，这种做法会很有用。您可以不以原始格式存储图像，而是将图像预处理为 TFRecord 格式，然后将其用于所有后续的处理和建模中。\n",
        "\n",
        "首先，让我们下载雪中的猫的[图像](https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg)，以及施工中的纽约威廉斯堡大桥的[照片](https://upload.wikimedia.org/wikipedia/commons/f/fe/New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5Lk2qrKvN0yu"
      },
      "source": [
        "### 提取图像"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3a0fmwg8lHdF"
      },
      "outputs": [],
      "source": [
        "cat_in_snow  = tf.keras.utils.get_file('320px-Felis_catus-cat_on_snow.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')\n",
        "williamsburg_bridge = tf.keras.utils.get_file('194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7aJJh7vENeE4"
      },
      "outputs": [],
      "source": [
        "display.display(display.Image(filename=cat_in_snow))\n",
        "display.display(display.HTML('Image cc-by: &lt;a \"href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg\"&gt;Von.grzanka&lt;/a&gt;'))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KkW0uuhcXZqA"
      },
      "outputs": [],
      "source": [
        "display.display(display.Image(filename=williamsburg_bridge))\n",
        "display.display(display.HTML('&lt;a \"href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg\"&gt;From Wikimedia&lt;/a&gt;'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VSOgJSwoN5TQ"
      },
      "source": [
        "### 写入 TFRecord 文件"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Azx83ryQEU6T"
      },
      "source": [
        "和以前一样，将特征编码为与 `tf.Example` 兼容的类型。这将存储原始图像字符串特征，以及高度、宽度、深度和任意 `label` 特征。后者会在您写入文件以区分猫和桥的图像时使用。将 `0` 用于猫的图像，将 `1` 用于桥的图像："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kC4TS1ZEONHr"
      },
      "outputs": [],
      "source": [
        "image_labels = {\n",
        "    cat_in_snow : 0,\n",
        "    williamsburg_bridge : 1,\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c5njMSYNEhNZ"
      },
      "outputs": [],
      "source": [
        "# This is an example, just using the cat image.\n",
        "image_string = open(cat_in_snow, 'rb').read()\n",
        "\n",
        "label = image_labels[cat_in_snow]\n",
        "\n",
        "# Create a dictionary with features that may be relevant.\n",
        "def image_example(image_string, label):\n",
        "  image_shape = tf.image.decode_jpeg(image_string).shape\n",
        "\n",
        "  feature = {\n",
        "      'height': _int64_feature(image_shape[0]),\n",
        "      'width': _int64_feature(image_shape[1]),\n",
        "      'depth': _int64_feature(image_shape[2]),\n",
        "      'label': _int64_feature(label),\n",
        "      'image_raw': _bytes_feature(image_string),\n",
        "  }\n",
        "\n",
        "  return tf.train.Example(features=tf.train.Features(feature=feature))\n",
        "\n",
        "for line in str(image_example(image_string, label)).split('\\n')[:15]:\n",
        "  print(line)\n",
        "print('...')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2G_o3O9MN0Qx"
      },
      "source": [
        "请注意，所有特征现在都存储在 `tf.Example` 消息中。接下来，函数化上面的代码，并将示例消息写入名为 `images.tfrecords` 的文件："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qcw06lQCOCZU"
      },
      "outputs": [],
      "source": [
        "# Write the raw image files to `images.tfrecords`.\n",
        "# First, process the two images into `tf.Example` messages.\n",
        "# Then, write to a `.tfrecords` file.\n",
        "record_file = 'images.tfrecords'\n",
        "with tf.io.TFRecordWriter(record_file) as writer:\n",
        "  for filename, label in image_labels.items():\n",
        "    image_string = open(filename, 'rb').read()\n",
        "    tf_example = image_example(image_string, label)\n",
        "    writer.write(tf_example.SerializeToString())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yJrTe6tHPCfs"
      },
      "outputs": [],
      "source": [
        "!du -sh {record_file}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jJSsCkZLPH6K"
      },
      "source": [
        "### 读取 TFRecord 文件\n",
        "\n",
        "现在，您有文件 `images.tfrecords`，并可以迭代其中的记录以将您写入的内容读取回来。因为在此示例中您只需重新生成图像，所以您只需要原始图像字符串这一个特征。使用上面描述的 getter 方法（即 `example.features.feature['image_raw'].bytes_list.value[0]`）提取该特征。您还可以使用标签来确定哪个记录是猫，哪个记录是桥："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M6Cnfd3cTKHN"
      },
      "outputs": [],
      "source": [
        "raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')\n",
        "\n",
        "# Create a dictionary describing the features.\n",
        "image_feature_description = {\n",
        "    'height': tf.io.FixedLenFeature([], tf.int64),\n",
        "    'width': tf.io.FixedLenFeature([], tf.int64),\n",
        "    'depth': tf.io.FixedLenFeature([], tf.int64),\n",
        "    'label': tf.io.FixedLenFeature([], tf.int64),\n",
        "    'image_raw': tf.io.FixedLenFeature([], tf.string),\n",
        "}\n",
        "\n",
        "def _parse_image_function(example_proto):\n",
        "  # Parse the input tf.Example proto using the dictionary above.\n",
        "  return tf.io.parse_single_example(example_proto, image_feature_description)\n",
        "\n",
        "parsed_image_dataset = raw_image_dataset.map(_parse_image_function)\n",
        "parsed_image_dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0PEEFPk4NEg1"
      },
      "source": [
        "从 TFRecord 文件中恢复图像："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yZf8jOyEIjSF"
      },
      "outputs": [],
      "source": [
        "for image_features in parsed_image_dataset:\n",
        "  image_raw = image_features['image_raw'].numpy()\n",
        "  display.display(display.Image(data=image_raw))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "pL--_KGdYoBz"
      ],
      "name": "tfrecord.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
